{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0deaff3f-56ae-4a74-9db0-44d6a25e9490",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-246-g3d31191b-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from IPython import display\n",
    "from d2l import torch as d2l\n",
    "def load_data_fashion_mnist(batch_size, resize=None):\n",
    "    \"\"\"Download Fashion-MNIST and wrap the train/test splits in DataLoaders.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of samples per mini-batch for both loaders.\n",
    "        resize: optional size handed to transforms.Resize, applied before ToTensor.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (train_loader, test_loader); only the training loader shuffles.\n",
    "    \"\"\"\n",
    "    # torchvision and torch.utils.data are not imported at notebook level,\n",
    "    # so bring them into scope here to keep this function callable.\n",
    "    import torchvision\n",
    "    from torch.utils import data\n",
    "    from torchvision import transforms\n",
    "\n",
    "    trans = [transforms.ToTensor()]\n",
    "    if resize:\n",
    "        # Resize must run on the image before it is converted to a tensor.\n",
    "        trans.insert(0, transforms.Resize(resize))\n",
    "    trans = transforms.Compose(trans)\n",
    "    mnist_train = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)\n",
    "    mnist_test = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)\n",
    "    # No get_dataloader_workers is defined in this notebook; use the d2l helper of that name.\n",
    "    num_workers = d2l.get_dataloader_workers()\n",
    "    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers),\n",
    "            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=num_workers))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "394618de-f801-4046-b141-bf894e9e2b34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch size handed to the data loader below.\n",
    "batch_size = 256\n",
    "# Uses d2l's loader; the load_data_fashion_mnist defined in the first cell is not called here.\n",
    "train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3dcb118a-6295-46bf-a85d-303fda6f9457",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the model parameters\n",
    "# NOTE(review): 784 presumably corresponds to flattened 28x28 images - confirm.\n",
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "# W has shape (num_inputs, num_outputs), drawn from a normal with mean 0 and std 0.01;\n",
    "# b starts at zero. Both tensors track gradients.\n",
    "W = torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)\n",
    "b = torch.zeros(num_outputs,requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d253505d-2f97-4c15-90ae-519ddee1f950",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[5., 7., 9.]]),\n",
       " tensor([[ 6.],\n",
       "         [15.]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define the softmax operation\n",
    "# Warm-up: sum along dim 0 and along dim 1; keepdim=True keeps the reduced axis with size 1.\n",
    "x = torch.tensor([[1.0,2.0,3.0],[4.0,5.0,6.0]])\n",
    "x.sum(0,keepdim=True),x.sum(1,keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "aabc168a-60cb-4f2d-98f5-670b4856d4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(x):\n",
    "    \"\"\"Softmax along dim 1 of a tensor.\n",
    "\n",
    "    Args:\n",
    "        x: tensor with at least two dimensions and a non-empty dim 1.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of the same shape whose entries along dim 1 are non-negative and sum to 1.\n",
    "    \"\"\"\n",
    "    # Softmax is unchanged by subtracting a constant from every entry of a row.\n",
    "    # Shifting by the row maximum keeps exp() from overflowing to inf, which\n",
    "    # would otherwise turn the division below into nan.\n",
    "    row_max = x.max(dim=1, keepdim=True).values.detach()\n",
    "    X_exp = torch.exp(x - row_max)\n",
    "    partition = X_exp.sum(1, keepdim=True)\n",
    "    # Broadcasting divides every row by its own partition value.\n",
    "    return X_exp / partition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2ae7e0b2-32a2-4468-9fa5-d3c60e9ebcb1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.1397, 0.4685, 0.1552, 0.0955, 0.1412],\n",
       "         [0.0232, 0.0242, 0.6269, 0.1480, 0.1777]]),\n",
       " tensor([1.0000, 1.0000]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on a random 2x5 input: every row of the result should sum to 1.\n",
    "x = torch.normal(0,1,(2,5))\n",
    "x_prob = softmax(x)\n",
    "x_prob ,x_prob.sum(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e4ce0f14-564e-42e3-bca8-df500e1a1c1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the model\n",
    "def net(x):\n",
    "    \"\"\"Flatten x into rows of W.shape[0] features, apply x @ W + b, then softmax.\"\"\"\n",
    "    num_features = W.shape[0]\n",
    "    flat = x.reshape((-1, num_features))\n",
    "    logits = torch.matmul(flat, W) + b\n",
    "    return softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0b585f7c-44d1-49bd-96d8-a4cb247ed925",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1000, 0.5000])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define the loss function\n",
    "# y supplies one column index per row of y_hat.\n",
    "y = torch.tensor([0,2])\n",
    "y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])\n",
    "# Advanced indexing: selects y_hat[0, y[0]] and y_hat[1, y[1]].\n",
    "y_hat[[0,1],y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b4a53a7c-a791-457b-a779-659eba0ed124",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.3026, 0.6931])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cross_entropy(y_hat, y):\n",
    "    \"\"\"Negative log of the entry each row of y_hat holds at its label index.\n",
    "\n",
    "    Args:\n",
    "        y_hat: tensor indexed as [row, column], one row per sample.\n",
    "        y: one column index per row of y_hat.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with one loss value per row (no reduction is applied).\n",
    "    \"\"\"\n",
    "    row_index = range(len(y_hat))\n",
    "    # Pair row i with column y[i] to pick one entry per sample.\n",
    "    picked = y_hat[row_index, y]\n",
    "    return -torch.log(picked)\n",
    "# Apply the loss to the two-sample y_hat / y demo tensors from the previous cell.\n",
    "cross_entropy(y_hat,y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "7680a0bb-81f4-4456-b180-6eff512d3045",
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(y_hat, y):\n",
    "    \"\"\"Count how many predictions match the labels.\n",
    "\n",
    "    Args:\n",
    "        y_hat: predictions; when it has more than one dimension and more than\n",
    "            one column, the argmax along axis 1 is used as the prediction.\n",
    "        y: label tensor the predictions are compared against.\n",
    "\n",
    "    Returns:\n",
    "        Number of matching entries as a Python float (a count, not a ratio).\n",
    "    \"\"\"\n",
    "    has_class_scores = y_hat.dim() > 1 and y_hat.shape[1] > 1\n",
    "    predicted = y_hat.argmax(axis=1) if has_class_scores else y_hat\n",
    "    # Cast to the label dtype before comparing, then again so the booleans can be summed.\n",
    "    matches = predicted.type(y.dtype) == y\n",
    "    return float(matches.type(y.dtype).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2dd6e08f-674e-4b55-90ce-cf50b68e3e90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# accuracy() returns a count of matches, so divide by the number of labels to get a fraction.\n",
    "accuracy(y_hat,y) / len(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "77bfd7d1-7280-41f0-a853-717183a30dcd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1206"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Accumulator:\n",
    "    \"\"\"Keep running float sums for a fixed number of quantities.\"\"\"\n",
    "\n",
    "    def __init__(self, n):\n",
    "        # One float slot per tracked quantity, all starting at zero.\n",
    "        self.data = [0.0 for _ in range(n)]\n",
    "\n",
    "    def add(self, *args):\n",
    "        \"\"\"Add each positional value, converted to float, to the slot at the same position.\n",
    "\n",
    "        Pairing uses zip, so the stored list ends up as long as the shorter\n",
    "        of the current slots and the supplied values.\n",
    "        \"\"\"\n",
    "        updated = []\n",
    "        for total, value in zip(self.data, args):\n",
    "            updated.append(total + float(value))\n",
    "        self.data = updated\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Set every slot back to 0.0, keeping the current number of slots.\"\"\"\n",
    "        self.data = [0.0 for _ in self.data]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Expose the running sums by position.\n",
    "        return self.data[idx]\n",
    "\n",
    "\n",
    "def evaluate_accuracy(net, data_iter):\n",
    "    \"\"\"Fraction of examples in data_iter that net predicts correctly.\n",
    "\n",
    "    Args:\n",
    "        net: callable mapping a batch of inputs to predictions; a\n",
    "            torch.nn.Module is switched to evaluation mode first.\n",
    "        data_iter: iterable yielding (inputs, labels) batches.\n",
    "\n",
    "    Returns:\n",
    "        Correct predictions divided by the number of label elements seen.\n",
    "    \"\"\"\n",
    "    if isinstance(net, torch.nn.Module):\n",
    "        net.eval()\n",
    "    totals = Accumulator(2)  # [correct predictions, label elements seen]\n",
    "    # Gradients are not needed for evaluation.\n",
    "    with torch.no_grad():\n",
    "        for features, labels in data_iter:\n",
    "            predictions = net(features)\n",
    "            totals.add(accuracy(predictions, labels), labels.numel())\n",
    "    correct, seen = totals[0], totals[1]\n",
    "    return correct / seen\n",
    "\n",
    "# Evaluate net on the test iterator (no training step appears in the cells above).\n",
    "evaluate_accuracy(net,test_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5ddd03dd-ceea-4d52-acee-0d620c44ad72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training\n",
    "def train_epoch_ch3(net, train_iter, loss, updater):\n",
    "    \"\"\"Run one pass over train_iter, updating the model after every batch.\n",
    "\n",
    "    Args:\n",
    "        net: callable mapping a batch of inputs to predictions; a\n",
    "            torch.nn.Module is switched to training mode first.\n",
    "        train_iter: iterable yielding (x, y) batches.\n",
    "        loss: callable returning a loss tensor for (y_hat, y); it is reduced\n",
    "            here with mean() or sum() before backward().\n",
    "        updater: either a torch.optim.Optimizer, or a callable that is\n",
    "            invoked with the batch size after backward().\n",
    "\n",
    "    Returns:\n",
    "        Tuple (summed loss / label elements seen, correct predictions / label elements seen).\n",
    "    \"\"\"\n",
    "    if isinstance(net, torch.nn.Module):\n",
    "        net.train()\n",
    "\n",
    "    # Running totals: [summed loss, correct predictions, label elements seen].\n",
    "    metric = Accumulator(3)\n",
    "    for x, y in train_iter:\n",
    "        y_hat = net(x)\n",
    "        l = loss(y_hat, y)\n",
    "        if isinstance(updater, torch.optim.Optimizer):\n",
    "            updater.zero_grad()\n",
    "            l.mean().backward()\n",
    "            updater.step()\n",
    "        else:\n",
    "            # Custom updater: back-propagate the summed loss and pass the batch size along.\n",
    "            # NOTE(review): presumably the updater divides by the batch size and clears\n",
    "            # the gradients itself - confirm against the updater that is passed in.\n",
    "            l.sum().backward()\n",
    "            updater(x.shape[0])\n",
    "        # Tensor.numel() is the element count; `numer` is not a Tensor method.\n",
    "        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())\n",
    "    return metric[0] / metric[2], metric[1] / metric[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3780763-88a1-411e-b0ba-d23ecd122a21",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
